#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import json
import glob
import itertools
import sys
import time
import re
import copy
from core.utils.exec_cmd import ExecCmd
from core.utils.memory import get_memory
from core.utils.log import logger
from core.utils.path import create_path, create_path_force, destroy_path,\
    create_link_force, copy_file
from multiprocessing import cpu_count
from core.utils.param import ParamConvert
from core.utils.configure import TestConfig

major_version = sys.version_info[0]
if major_version == 2:
    from urllib import quote
else:
    from urllib.request import quote


def init_test_info(test, appconf):
    """
    Build the TestConfig describing a single test suite.

    The per-suite source/run/cache/build/result directories are the matching
    application-level roots joined with the suite name, and the YAML
    definition is ``<suite>.yaml`` under ``appconf.test_yaml_path``. The
    optional ``timeout`` and ``keep`` entries are taken from
    ``appconf.parser`` when present (defaulting to None and False).

    :param test: test suite name.
    :param appconf: application configuration object.
    :return: the populated TestConfig.
    """
    testconf = TestConfig(test)
    testconf.subcommand = appconf.subcommand
    for attr, root in (
            ('source_path', appconf.test_source_path),
            ('run_path', appconf.test_run_path),
            ('cache_path', appconf.test_cache_path),
            ('build_path', appconf.test_build_path)):
        setattr(testconf, attr, os.path.join(root, test))
    testconf.bash_lib_path = appconf.bash_lib_path
    testconf.result_path = os.path.join(appconf.test_result_path, test)
    testconf.conf_path = appconf.test_conf_path
    testconf.yaml_path = os.path.join(appconf.test_yaml_path, test + '.yaml')
    testconf.shell_map = appconf.test_interface_shell_map
    testconf.timeout = (appconf.parser['timeout']
                        if 'timeout' in appconf.parser else None)
    testconf.git_user = appconf.git_user
    # Percent-encode the password. quote() leaves '/' unescaped by default.
    # NOTE(review): if this is embedded in a URL elsewhere, confirm whether
    # '/' (and the unquoted git_user) also need escaping there.
    testconf.git_password = quote(appconf.git_password)
    testconf.tone_path = appconf.tone_path
    testconf.tone_bin_path = os.path.join(appconf.tone_path, 'bin')
    testconf.is_quiet = appconf.is_quiet
    testconf.keep_build = (appconf.parser['keep']
                           if 'keep' in appconf.parser else False)
    return testconf


def read_conf(fn, suite='-', converter=None):
    """
    Parse a whitespace-separated test-suite configuration table.

    The first non-blank line is a header naming the parameters. Every further
    non-blank line is one configuration whose cells line up with the header.
    Blank lines anywhere in the file are ignored.

    A cell of ``Na``/``na``/``NA`` is stored as the string ``'None'``. Any other
    cell is stored as ``str(converter[field].get(cell))``. Unless the header
    has a ``testconf`` column, each row gets a generated ``testconf`` name of the
    form ``<suite>:<field>=<raw value>,...`` (Na fields omitted), or
    ``<suite>:default`` when every field is Na.

    :param fn: path of the configuration file.
    :param suite: suite name used as prefix of generated testconf names.
    :param converter: mapping field -> object with a ``get(raw)`` method;
        defaults to ``ParamConvert()``.
    :return: ``(value, indvalue, fields)``. ``value`` maps testconf name ->
        row dict. ``indvalue`` is the list of row dicts in file order.
        ``fields`` is the header list. An empty file yields a single
        ``<suite>:default`` row and ``fields == []``.
    :raises ValueError: if a row has fewer cells than the header.
    """
    if converter is None:
        converter = ParamConvert()
    with open(fn) as fh:
        # Keep the 1-based file line number of each non-blank line for errors.
        rows = [(lineno, line.split())
                for lineno, line in enumerate(fh, 1) if line.strip()]

    if not rows:
        testconf = suite + ':default'
        recorder = {'testconf': testconf}
        return {testconf: recorder}, [recorder], []

    fields = rows[0][1]
    value = {}
    indvalue = []
    for lineno, cells in rows[1:]:
        if len(cells) < len(fields):
            raise ValueError(
                "{}:{}: expected {} columns ({}), got {}".format(
                    fn, lineno, len(fields), " ".join(fields), len(cells)))
        recorder = {}
        origin_value = {}
        for field, cell in zip(fields, cells):
            if cell in ('Na', 'na', 'NA'):
                recorder[field] = str(None)
                origin_value[field] = 'Na'
            else:
                recorder[field] = str(converter[field].get(cell))
                origin_value[field] = cell
        if 'testconf' not in fields:
            params = ",".join(f + "=" + origin_value[f]
                              for f in fields if origin_value[f] != "Na")
            recorder['testconf'] = suite + ":" + (params or 'default')
        value[recorder['testconf']] = recorder
        indvalue.append(recorder)
    return value, indvalue, fields


class TestCommand(ExecCmd):
    """
    ExecCmd subclass that runs a test's bash command, optionally after
    switching into a working directory.
    """

    def __init__(self, command, workdir=None, env=None, timeout=None,
                 logpath=None, save2mem=False, is_quiet=True):
        # NOTE: os.chdir changes the working directory of the whole process,
        # not just of this command, and it is not restored afterwards.
        if workdir:
            os.chdir(workdir)
        options = dict(
            env=env,
            timeout=timeout,
            logpath=logpath,
            save2mem=save2mem,
            is_quiet=is_quiet,
        )
        super(TestCommand, self).__init__(command, **options)


class TestCommands(object):
    """
    Manage execution of a test suite's bash stages (fetch, install, run,
    parse, uninstall, testinfo) and post-process their output into
    result.json / statistic.json.
    """

    # Commands whose output dump_system_info() captures. __init__ works on a
    # per-instance copy, so this class-level default is never mutated.
    info_commands = [
        "df -mP",
        "dmesg",
        "uname -a",
        "lspci -vvnn",
        "gcc --version",
        "ld --version",
        "hostname",
        "uptime",
        "dmidecode",
        "ifconfig -a",
        "ip link",
        "numactl --hardware show",
        "lscpu",
        "fdisk -l",
        "sysctl -a",
    ]

    # Characters that may appear unquoted in a shell word. Any other
    # character (whitespace, ',', '=', ';', '$', quotes, ...) forces quoting.
    _ENV_UNSAFE_CHARS = re.compile(r'[^A-Za-z0-9_@%+:./-]')
    # Characters that keep a special meaning inside double quotes in sh/bash.
    _ENV_ESCAPE_CHARS = re.compile(r'([\\"$`])')

    def __init__(self, config=None):
        """
        :param config: the suite configuration (as built by init_test_info).
            It is required despite the None default: attributes are read
            from it immediately.
        """
        self.current_executor = None
        self.result = None
        self.config = config
        self.update_log_path()
        self.stoptime = self.config.stoptime
        # Copy before appending so the package-listing command is added to
        # this instance only, not to the shared class attribute.
        self.info_commands = list(self.info_commands)
        if "rpm -qa" not in self.info_commands and \
                "dpkg -l" not in self.info_commands:
            if os.system("which yum > /dev/null 2>&1") == 0:
                self.info_commands.append("rpm -qa")
            elif os.system("which apt > /dev/null 2>&1") == 0:
                self.info_commands.append("dpkg -l")
            else:
                logger.error("Unsupported Distribution")

    def update_log_path(self):
        """
        Point stdout.log, stderr.log and result.json at the current result
        directory (see _get_current_result_path).
        """
        base = self._get_current_result_path()
        self.logfile = os.path.join(base, 'stdout.log')
        self.errorlog = os.path.join(base, 'stderr.log')
        self.parselog = os.path.join(base, 'result.json')

    def _get_current_result_path(self):
        """
        Return ``config.result_runtime_path`` if it is an existing
        directory, otherwise fall back to ``config.result_path``.
        """
        current = self.config.result_runtime_path
        if os.path.isdir(current):
            return current
        return self.config.result_path

    def _save2log(self):
        """
        Append the last command's stdout/stderr to the log files.
        """
        with open(self.logfile, "a+") as fh:
            fh.write(self.result.stdout)
        with open(self.errorlog, "a+") as fh:
            fh.write(self.result.stderr)

    def _run(self, command, env=None, halted_on_fail=False, save2mem=False):
        """
        Run ``command`` through TestCommand and store the outcome in
        ``self.result``.

        :param command: shell command line to execute.
        :param env: environment for the command. It defaults to a deep copy
            of ``config.env``, and ``config.var`` is merged on top of it.
        :param halted_on_fail: when True, exit the process with status 1 if
            the command's exit status is non-zero.
        :param save2mem: forwarded to TestCommand/ExecCmd.
        """
        logger.debug(command)

        def add2var(field, asfield=None):
            # Export ``field`` into config.var, optionally under ``asfield``.
            asfield = asfield or field
            if field == 'result_path':
                self.config.var[asfield] = self._get_current_result_path()
                return
            if self.config.get(field):
                self.config.var[asfield] = self.config[field]
            # Computed defaults apply only when nothing provided a value.
            if field not in self.config.var:
                if field == 'nr_cpu':
                    self.config.var[asfield] = str(cpu_count())
                elif field == 'memory':
                    self.config.var[asfield] = get_memory()

        if env is None:
            env = copy.deepcopy(self.config.env)

        add2var('result_path', 'TONE_CURRENT_RESULT_DIR')
        add2var('nr_cpu')
        add2var('memory')

        env.update(self.config.var)

        self.config.env = filter_env()
        self.dump_env()
        self.dump_env(envfn='var.sh', env=self.config.var)

        # Within a scenario, bound the command by the remaining time budget.
        if self.config.current_scenaria and self.config.stoptime:
            timeout = self.config.stoptime - time.time()
        else:
            timeout = None
        with TestCommand(
            command,
            env=env,
            timeout=timeout,
            logpath=self._get_current_result_path(),
            save2mem=save2mem,
            is_quiet=self.config.is_quiet
        ) as executor:
            self.result = executor.run()
        if halted_on_fail and self.result.exit_status != 0:
            logger.error("Exist for result fail")
            sys.exit(1)

    def _generate_command(self, command, ext=None):
        """
        Change into the working directory for stage ``command`` and build a
        bash snippet. The snippet sources common.sh and the suite's interface
        script, then runs ``ext`` (default: ``command``) if a shell function
        named ``command`` is defined.

        fetch/install/uninstall run in ``cache_path``.
        setup/teardown/run/parse run in ``run_path``.
        """
        if command in ['fetch', 'install', 'uninstall']:
            os.chdir(self.config.cache_path)
            logger.debug("Working Path: " + self.config.cache_path)
        elif command in ['setup', 'teardown', 'run', 'parse']:
            os.chdir(self.config.run_path)
            logger.debug("Working Path: " + self.config.run_path)
        if not ext:
            ext = command
        rv = ". {}/common.sh; . {}/{}; [ X`type -t {}` == X'function' ] && {}"
        return rv.format(
            self.config.bash_lib_path,
            self.config.source_path,
            self.config.shell_map[command], command, ext)

    def _reformat2json(self):
        """
        Convert the parse stage's stdout into result.json.

        Functional suites: lines ``<case>: pass|fail|skip|warning[, expect:
        X][, current: Y]`` become ``{'testcase', 'value'[, 'expect',
        'current']}``. Other suites: lines ``<metric>: <number> [unit]``
        become ``{'testcase', 'value': {'value', 'unit'}}``. Keys may contain
        single spaces. Non-matching lines that contain ':' are reported and
        skipped.
        """
        values = []
        is_function_test = self.config.category == 'functional'
        if is_function_test:
            pattern = re.compile(
                r'^((?:\S+ )*\S+):\s(pass|fail|skip|warning)', re.IGNORECASE)
        else:
            pattern = re.compile(
                r'^((?:\S+ )*\S+):\s([1-9]\d*\.\d*|0\.\d*|[1-9]\d*|0)\s*(\S*)\s*$')
        for line in self.result.stdout.split('\n'):
            if ':' not in line:
                continue
            match = pattern.search(line)
            if not match:
                print("Skip line: {}".format(line))
                continue
            if is_function_test:
                k, v = match.groups()
                entry = {'testcase': k, 'value': v.title()}
                if "expect" in line:
                    for segment in line.split(",")[1:]:
                        # Keep everything after the first ':' so values that
                        # themselves contain ':' are not truncated.
                        _, sep, text = segment.partition(":")
                        if not sep:
                            continue
                        if "expect" in segment:
                            entry["expect"] = text.strip()
                        if "current" in segment:
                            entry["current"] = text.strip()
                values.append(entry)
            else:
                k, v, unit = match.groups()
                values.append({
                    "testcase": k,
                    "value": {"value": v, "unit": unit},
                })
        with open(self.parselog, 'w') as fh:
            json.dump(values, fh, indent=2)

    def _stage_script(self, script, dest_dir):
        """
        Copy ``script`` from tone's bin directory into ``dest_dir``. Return
        the command line that runs it for this suite (``<script path>
        <suite name>``).
        """
        command = os.path.join(self.config.tone_bin_path, script)
        copy_file(command, os.path.join(dest_dir, script))
        return "{} {}".format(command, self.config.name)

    def fetch(self):
        """
        Run fetch_test.sh for the suite and exit the process on failure.
        """
        # Called only for its side effect: chdir into cache_path.
        self._generate_command('fetch')
        command = self._stage_script('fetch_test.sh', self.config.result_path)
        self._run(command, halted_on_fail=True)

    def install(self):
        """
        Run install_test.sh for the suite and exit the process on failure.
        On success the build directory is removed unless keep_build is set.
        """
        create_path(self.config.build_path)
        # Called only for its side effect: chdir into cache_path.
        self._generate_command('install')
        command = self._stage_script('install_test.sh', self.config.result_path)
        self._run(command, halted_on_fail=True)
        if self.result and self.result.exit_status == 0:
            if not self.config.keep_build:
                destroy_path(self.config.build_path)

    def run(self):
        """
        Run run_test.sh for the suite and record its exit status.
        """
        # Called only for its side effect: chdir into run_path.
        self._generate_command('run')
        converter = ParamConvert()
        for key in converter:
            generated = converter[key].generate()
            if generated:
                logger.debug(generated)
            else:
                logger.debug("Not need to generate code")
        command = self._stage_script(
            'run_test.sh', self.config.result_runtime_path)
        self._run(command, halted_on_fail=False)
        self._save_runstatus()

    def _save_runstatus(self):
        """
        Write the last command's exit status to runstatus.json in the
        current run directory.
        """
        run_status = {
            'exit_status': self.result.exit_status,
        }
        test_status = os.path.join(
            self.config.result_runtime_path, 'runstatus.json')
        with open(test_status, 'w') as fh:
            json.dump(run_status, fh, indent=2)

    def _get_run_status(self):
        """
        Return 'fail' if any run directory's runstatus.json has a non-zero
        exit_status, else 'pass'.
        """
        status = 'pass'
        for d in self._get_result_dirs():
            test_status = os.path.join(d, 'runstatus.json')
            if not os.path.isfile(test_status):
                continue
            with open(test_status, 'r') as fp:
                tmp = json.load(fp)
            if 'exit_status' in tmp and tmp['exit_status'] != 0:
                status = 'fail'
                break
        return status

    def collect_testinfo(self):
        """
        Run testinfo.sh to collect information about the test machine.
        """
        command = self._stage_script(
            'testinfo.sh', self.config.result_runtime_path)
        self._run(command, halted_on_fail=False)

    def parse(self):
        """
        Run parse_test.sh (output kept in memory), then convert the output
        to result.json.
        """
        # Called only for its side effect: chdir into run_path.
        self._generate_command(
            "parse", "cat {} | parse".format(self.logfile))
        command = self._stage_script(
            'parse_test.sh', self.config.result_runtime_path)
        self._run(command, save2mem=True)
        self._reformat2json()

    def uninstall(self):
        """
        Run uninstall_test.sh for the suite.
        """
        # Called only for its side effect: chdir into cache_path.
        self._generate_command('uninstall')
        command = self._stage_script(
            'uninstall_test.sh', self.config.result_path)
        self._run(command)

    def dump_env(self, envfn='env.sh', env=None):
        """
        Write ``env`` (default: ``config.env``) as ``export K=V`` lines to
        ``envfn`` in the current result directory.

        Values containing anything other than ``[A-Za-z0-9_@%+:./-]`` are
        wrapped in double quotes. ``\\``, ``"``, ``$`` and backticks are
        escaped so that sourcing the file reproduces the literal values.
        """
        if env is None:
            env = self.config.env
        envfile = os.path.join(self._get_current_result_path(), envfn)
        with open(envfile, 'w') as fh:
            for k, v in env.items():
                if self._ENV_UNSAFE_CHARS.search(v):
                    escaped = self._ENV_ESCAPE_CHARS.sub(r'\\\1', v)
                    fh.write('export {}="{}"\n'.format(k, escaped))
                else:
                    fh.write("export {}={}\n".format(k, v))

    def dump_system_info(self, ext='pre'):
        """
        Run each command in ``self.info_commands`` and save its combined
        output to ``<current result path>/sysinfo/<ext>/<command>``, with
        spaces in the command replaced by '_'.
        """
        logger.debug("System info")
        logger.debug(self.config)
        result_path = os.path.join(
            self._get_current_result_path(), 'sysinfo', ext)
        # Tolerate an existing directory so a repeated call does not raise.
        if not os.path.isdir(result_path):
            os.makedirs(result_path)
        env = self.config.env
        for cmd in self.info_commands:
            logger.debug("Command: {}".format(cmd))
            outputlog = os.path.join(result_path, cmd.replace(" ", "_"))
            command = "{} > {} 2>&1".format(cmd, outputlog)
            task = ExecCmd(command, env=env)
            result = task.run()
            logger.debug(result.exit_status)

    def _get_result_dirs(self):
        """
        Return the paths matching ``[0-9]*`` under the scenario result path,
        in arbitrary order.
        """
        return glob.glob(
            os.path.join(self.config.result_scenaria_path, '[0-9]*'))

    def _sorted_result_dirs(self):
        """
        Return the all-digit run directories of the current scenario, sorted
        numerically (run 10 after run 2, unlike a plain string sort).
        """
        dirs = [d for d in self._get_result_dirs()
                if os.path.basename(d).isdigit()]
        return sorted(dirs, key=lambda d: int(os.path.basename(d)))

    def static(self):
        """
        Aggregate the runs' result.json files into statistic.json.

        Functional suites record, per testcase, the per-run verdicts
        ('matrix') plus any 'expect'/'current' annotations. Other suites
        record, per metric, the per-run values ('matrix'), their mean
        ('average'), the coefficient of variation ('cv' = population stdev /
        mean, 0 when the mean is ~0) and the unit of the first run. 'status'
        comes from the runs' runstatus.json files and is forced to 'fail'
        when no results were found. Finally the current run's stdout.log is
        copied into the scenario directory as ``<suite>.stdout.log``.
        """
        logger.debug("Static function run...")
        data = {}
        testorder = []
        source = []
        is_function_test = self.config.category == 'functional'
        testinfo = {
            'suite': self.config.name,
            'start': self.config.starttime,
            'end': time.time(),
            'params': self.config.current_scenaria['env'],
            # NOTE(review): this stores the whole os.uname() result, not
            # only the node name; kept as-is for statistic.json consumers.
            'hostname': os.uname(),
            'os_version': read_distribute(),
            'category': self.config.category,
            'testconf': self.config.current_scenaria['scenaria'],
        }
        # Runs are visited in numeric order so 'matrix'/'source' follow the
        # run sequence.
        for run_dir in self._sorted_result_dirs():
            resultfile = os.path.join(run_dir, 'result.json')
            if not os.path.isfile(resultfile):
                continue
            source.append(resultfile)
            with open(resultfile, 'r') as fp:
                items = json.load(fp)
            for item in items:
                k = item['testcase']
                v = item['value']
                if k not in data:
                    if is_function_test:
                        data[k] = {"value": [], "expect": [], "current": []}
                    else:
                        data[k] = []
                    testorder.append(k)
                try:
                    if is_function_test:
                        data[k]["value"].append(v)
                        data[k]["expect"].append(item.get("expect", ""))
                        data[k]["current"].append(item.get("current", ""))
                    else:
                        data[k].append(v)
                except Exception:
                    logger.error(
                        "{} contain wrong data!".format(resultfile))
                    raise
        logger.debug(data)

        results_array = []
        if is_function_test:
            pass_rate = {}
            for test in testorder:
                verdicts = data[test]["value"]
                pass_rate[test] = \
                    float(sum(1 for v in verdicts if v == 'Pass')) / len(verdicts)
                results_array.append({
                    'matrix': verdicts,
                    'expect': data[test]["expect"],
                    'current': data[test]["current"],
                    'testcase': test,
                })
            logger.debug(pass_rate)
        else:
            summary = {}
            for metric in testorder:
                raw_data = [float(d["value"]) for d in data[metric]]
                average = self._avg(raw_data)
                # avoid division by zero
                if abs(average) < 0.000000001:
                    cv = 0
                else:
                    cv = self._stdev(raw_data, average) / average
                summary[metric] = (average, cv)
                results_array.append({
                    "average": average,
                    "cv": cv,
                    "unit": data[metric][0]["unit"],
                    "matrix": raw_data,
                    "metric": metric,
                })
            logger.debug(summary)

        run_status = self._get_run_status()
        if not results_array:
            run_status = 'fail'
        statistic = {
            'testinfo': testinfo,
            'results': results_array,
            'source': source,
            'status': run_status,
        }
        statistic_file = os.path.join(
            self.config.result_scenaria_path, 'statistic.json')
        with open(statistic_file, 'w') as fp:
            json.dump(statistic, fp, indent=2, sort_keys=True)
        # Copy stdout.log
        current_stdout = os.path.join(
            self._get_current_result_path(), "stdout.log")
        _scenaria_name = self.config.current_scenaria["scenaria"].split(":")[0]
        scenaria_stdout = os.path.join(
            self.config.result_scenaria_path, _scenaria_name + ".stdout.log")
        copy_file(current_stdout, scenaria_stdout)

    def _avg(self, data):
        """Return the arithmetic mean of ``data`` as a float."""
        return float(sum(data)) / len(data)

    def _stdev(self, data, m):
        """Return the population standard deviation of ``data`` around ``m``."""
        stdev = float(sum((x - m) ** 2 for x in data))
        return (stdev / len(data)) ** 0.5


class TestEnvInit(object):
    """
    Prepare the directories and environment used to run a test suite.
    """

    def __init__(self, config=None):
        # The suite configuration (a TestConfig, as built by init_test_info).
        self.config = config

    def create_cache_path(self):
        """Create the suite's cache directory (via create_path)."""
        create_path(self.config.cache_path)

    def create_run_path(self):
        """Create the suite's run directory (via create_path_force)."""
        create_path_force(self.config.run_path)

    def create_result_path(self):
        """
        Create the suite's result directory.
        """
        create_path(self.config.result_path)

    def create_result_runtime_path(self):
        """
        Create the result directory for the next run of the current scenario.

        The scenario directory is named after the parameter values of the
        scenario string, e.g. ``suite:a=1,b=2`` -> ``<result_path>/1-2`` and
        ``suite:default`` -> ``<result_path>/default``. Inside it, the run
        directory is one more than the highest existing all-digit
        subdirectory (starting at 1). Sets ``config.result_runtime_path``
        and ``config.result_scenaria_path``.
        """
        parts = self.config.current_scenaria['scenaria'].split(':')
        # "k=v" contributes "v"; an item without "=" is used as-is.
        param_values = [item.split("=")[-1] for item in parts[-1].split(",")]
        result_path = os.path.join(
            self.config.result_path, "-".join(param_values))

        runs = [int(os.path.basename(p))
                for p in glob.glob(os.path.join(result_path, '[0-9]*'))
                if os.path.basename(p).isdigit()]
        next_run = (max(runs) if runs else 0) + 1
        result_runtime_path = os.path.join(result_path, str(next_run))
        create_path(result_runtime_path)
        self.config.result_runtime_path = result_runtime_path
        self.config.result_scenaria_path = result_path

    def set_env(self):
        """
        Initialise ``config.env`` with a copy of the current process
        environment.
        """
        env = dict(os.environ)
        # TODO: read the suite's configured parameters into env.
        logger.debug('-' * 80)
        self.config.env = env

    def test_info(self):
        """
        Print a summary of the suite, scenario, settings and paths for the
        current subcommand.
        """
        suite = self.config.name
        setting = ""
        if self.config.current_scenaria:
            scenaria_env = self.config.current_scenaria.get('env') or {}
            for key in scenaria_env:
                setting += "\t{}={}".format(key, scenaria_env[key])
        if self.config.subcommand in ('fetch', 'install', 'uninstall'):
            run_path = self.config.cache_path
            result_path = os.path.join(self.config.result_path, 'stdout.log')
            conf = "Not needed at current stage"
            if self.config.subcommand in ('install', 'uninstall'):
                run_path = self.config.run_path
        else:
            run_path = self.config.run_path
            result_path = os.path.join(
                self.config.result_runtime_path, 'stdout.log')
            conf = self.config.current_scenaria['scenaria']
        cache_path = self.config.cache_path
        print("""
Test Suite   : {}
Test Conf    : {}
Test Setting : {}
Test Run Path: {}
Test Result  : {}
Test Cache   : {}
        """.format(suite, conf, setting, run_path, result_path, cache_path))

    def read_kconfig(self, kconfig_file=None):
        """
        Read a kernel config file into a dict keyed by each config line.

        Blank lines and '#' comment lines are skipped. Every other stripped
        line (e.g. ``CONFIG_FOO=y``) becomes a key with value 1.

        :param kconfig_file: path to read; defaults to
            ``/boot/config-<running kernel release>``.
        :return: dict mapping config line -> 1.
        """
        if kconfig_file is None:
            kconfig_file = os.path.join('/boot', 'config-' + os.uname()[2])
        kconfig = {}
        with open(kconfig_file, 'r') as fh:
            for line in fh:
                line = line.strip()
                if not line or line.startswith('#'):
                    continue
                kconfig[line] = 1
        return kconfig


def filter_env():
    """
    Return a copy of the process environment with ``PWD`` (when present)
    replaced by the actual current working directory.

    ``PWD`` in ``os.environ`` can be stale after ``os.chdir`` calls, so it is
    refreshed from ``os.getcwd()``.
    """
    env = {}
    for k, v in os.environ.items():
        if k == 'PWD':
            env[k] = os.getcwd()
        else:
            env[k] = v
    return env


def exists_scenaria_file(testdir, scenaria):
    """
    Tell whether the suite configuration file ``testdir/scenaria`` exists
    and is a regular file.

    :return: True or False.
    """
    return os.path.isfile(os.path.join(testdir, scenaria))


def mix_varables(variables):
    """
    Expand a parameter mapping into every combination of its values.

    Keys are taken in sorted order. A list value contributes each of its
    elements, and any other value counts as a single choice. For example,
    ``{'a': [1, 2], 'b': 'x'}`` gives
    ``[{'a': 1, 'b': 'x'}, {'a': 2, 'b': 'x'}]``.

    :param variables: mapping of parameter name -> value or list of values.
    :return: list of dicts, one per combination (Cartesian product).
    """
    names = sorted(variables.keys())
    choices = [variables[name] if isinstance(variables[name], list)
               else [variables[name]]
               for name in names]
    logger.debug(choices)
    return [dict(zip(names, combo)) for combo in itertools.product(*choices)]


def read_distribute(distribute_files=None):
    """
    Return the contents of the first existing distribution release file.

    :param distribute_files: candidate paths, checked in order. Defaults to
        the alinux, openEuler, kylin, redhat and os-release files under /etc.
    :return: the file's full text, or "" when none of the files exist
        (previously this case raised UnboundLocalError).
    """
    if distribute_files is None:
        distribute_files = [
            '/etc/alinux-release',
            '/etc/openEuler-release',
            '/etc/kylin-release',
            '/etc/redhat-release',
            '/etc/os-release',
        ]
    for distribute in distribute_files:
        if os.path.isfile(distribute):
            with open(distribute) as fh:
                return fh.read()
    return ""
